# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================

"""A deep MNIST classifier using convolutional layers.

This example was adapted from
https://pytorch.org/docs/master/distributed.html
https://pytorch.org/tutorials/intermediate/dist_tuto.html
https://github.com/narumiruna/pytorch-distributed-example/blob/master/mnist/main.py

Each worker reads the full MNIST dataset and asynchronously trains a CNN with dropout and
using the Adam optimizer, updating the model parameters on shared parameter servers.

The current training accuracy is printed out after every 100 steps.
"""


import argparse
import os

import torch
import torch.nn.functional as F
from torch import distributed, nn
from torch.utils.data import DataLoader
from torch.utils.data.distributed import DistributedSampler
from torchvision import datasets, transforms


class AverageMeter:
    def __init__(self):
        self.sum = 0
        self.count = 0

    def update(self, value, number):
        self.sum += value * number
        self.count += number

    @property
    def average(self):
        return self.sum / self.count


class AccuracyMeter:
    def __init__(self):
        self.correct = 0
        self.count = 0

    def update(self, output, label):
        predictions = output.data.argmax(dim=1)
        correct = predictions.eq(label.data).sum().item()

        self.correct += correct
        self.count += output.size(0)

    @property
    def accuracy(self):
        return self.correct / self.count


class Trainer:
    def __init__(self, net, optimizer, train_loader, test_loader, device):
        self.net = net
        self.optimizer = optimizer
        self.train_loader = train_loader
        self.test_loader = test_loader
        self.device = device

    def train(self):
        train_loss = AverageMeter()
        train_acc = AccuracyMeter()

        self.net.train()

        for data, label in self.train_loader:
            data = data.to(self.device)
            label = label.to(self.device)

            output = self.net(data)
            loss = F.cross_entropy(output, label)

            self.optimizer.zero_grad()
            loss.backward()
            # average the gradients
            self.average_gradients()
            self.optimizer.step()

            train_loss.update(loss.item(), data.size(0))
            train_acc.update(output, label)

        return train_loss.average, train_acc.accuracy

    def evaluate(self):
        test_loss = AverageMeter()
        test_acc = AccuracyMeter()

        self.net.eval()

        with torch.no_grad():
            for data, label in self.test_loader:
                data = data.to(self.device)
                label = label.to(self.device)

                output = self.net(data)
                loss = F.cross_entropy(output, label)

                test_loss.update(loss.item(), data.size(0))
                test_acc.update(output, label)

        return test_loss.average, test_acc.accuracy

    def average_gradients(self):
        world_size = distributed.get_world_size()

        for p in self.net.parameters():
            group = distributed.new_group(ranks=list(range(world_size)))

            tensor = p.grad.data.cpu()

            distributed.all_reduce(tensor, op=distributed.reduce_op.SUM, group=group)

            tensor /= float(world_size)

            p.grad.data = tensor.to(self.device)


class Net(nn.Module):
    def __init__(self):
        super().__init__()
        self.fc = nn.Linear(784, 10)

    def forward(self, x):
        return self.fc(x.view(x.size(0), -1))


def get_dataloader(root, batch_size):
    transform = transforms.Compose(
        # https://github.com/psf/black/issues/2434
        # fmt: off
        [transforms.ToTensor(),
         transforms.Normalize((0.13066047740239478,), (0.3081078087569972,))]
        # fmt: on
    )

    train_set = datasets.MNIST(root, train=True, transform=transform, download=True)
    sampler = DistributedSampler(train_set)

    train_loader = DataLoader(train_set, batch_size=batch_size, shuffle=(sampler is None), sampler=sampler)

    test_loader = DataLoader(
        datasets.MNIST(root, train=False, transform=transform, download=True),
        batch_size=batch_size,
        shuffle=False,
    )

    return train_loader, test_loader


def solve(args):
    device = torch.device("cuda" if args.cuda else "cpu")

    net = Net().to(device)

    optimizer = torch.optim.Adam(net.parameters(), lr=args.learning_rate)

    train_loader, test_loader = get_dataloader(args.root, args.batch_size)

    trainer = Trainer(net, optimizer, train_loader, test_loader, device)

    for epoch in range(1, args.epochs + 1):
        train_loss, train_acc = trainer.train()
        test_loss, test_acc = trainer.evaluate()

        print(
            f"Epoch: {epoch}/{args.epochs},",
            "train loss: {:.6f}, train acc: {:.6f}, test loss: {:.6f}, test acc: {:.6f}.".format(
                train_loss, train_acc, test_loss, test_acc
            ),
        )


def init_process(args):
    distributed.init_process_group(
        backend=args.backend,
        init_method=args.init_method,
        rank=args.rank,
        world_size=args.world_size,
    )


def main():
    parser = argparse.ArgumentParser()
    parser.add_argument("--backend", type=str, default="tcp", help="Name of the backend to use.")
    parser.add_argument(
        "--init-method",
        "-i",
        type=str,
        default=os.environ.get("INIT_METHOD", "tcp://127.0.0.1:23456"),
        help="URL specifying how to initialize the package.",
    )
    parser.add_argument(
        "--rank",
        "-r",
        type=int,
        default=int(os.environ.get("RANK")),
        help="Rank of the current process.",
    )
    parser.add_argument(
        "--world-size",
        "-s",
        type=int,
        default=int(os.environ.get("WORLD")),
        help="Number of processes participating in the job.",
    )
    parser.add_argument("--epochs", type=int, default=20)
    parser.add_argument("--no-cuda", action="store_true")
    parser.add_argument("--learning-rate", "-lr", type=float, default=1e-3)
    parser.add_argument("--root", type=str, default="data")
    parser.add_argument("--batch-size", type=int, default=128)
    args = parser.parse_args()
    args.cuda = torch.cuda.is_available() and not args.no_cuda
    print(args)

    init_process(args)
    solve(args)


if __name__ == "__main__":
    main()
